# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
# pylint: disable=R0917
import logging
from json import JSONDecodeError

from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.language_models import BaseChatModel
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph.graph import StateGraph
from pydantic import BaseModel
from pydantic import Field

from aiq.agent.base import AGENT_CALL_LOG_MESSAGE
from aiq.agent.base import AGENT_LOG_PREFIX
from aiq.agent.base import INPUT_SCHEMA_MESSAGE
from aiq.agent.base import NO_INPUT_ERROR_MESSAGE
from aiq.agent.base import TOOL_NOT_FOUND_ERROR_MESSAGE
from aiq.agent.base import AgentDecision
from aiq.agent.base import BaseAgent

logger = logging.getLogger(__name__)


class ReWOOGraphState(BaseModel):
    """State schema for the ReWOO Agent Graph.

    Every field has a default factory, so an empty state can be constructed and filled in step by step.
    """
    # The task provided by the user.
    task: HumanMessage = Field(default_factory=lambda: HumanMessage(content=""))

    # The plan generated by the planner to solve the task.
    plan: AIMessage = Field(default_factory=lambda: AIMessage(content=""))

    # The steps to solve the task, parsed from the plan.
    steps: AIMessage = Field(default_factory=lambda: AIMessage(content=""))

    # The intermediate result of each executed step, keyed by that step's evidence placeholder.
    intermediate_results: dict[str, ToolMessage] = Field(default_factory=dict)

    # The final result of the task, generated by the solver.
    result: AIMessage = Field(default_factory=lambda: AIMessage(content=""))


class ReWOOAgentGraph(BaseAgent):
    """Configurable LangGraph ReWOO Agent.

    The agent runs as a three-phase graph:

    1. ``planner``: asks the LLM for a JSON list of steps. Each step is a dict with a ``plan`` entry and an
       ``evidence`` entry holding ``placeholder``, ``tool`` and ``tool_input``.
    2. ``executor``: runs one step per invocation. It substitutes earlier evidence placeholders into the
       tool input and records the tool output under the step's placeholder. It loops until every step has
       a result.
    3. ``solver``: renders the plan together with the collected evidence and asks the LLM for the final
       answer.

    Tool calls are made with up to 3 retries. Argument "detailed_logs" toggles logging of inputs, outputs,
    and intermediate steps.
    """

    def __init__(self,
                 llm: BaseChatModel,
                 planner_prompt: ChatPromptTemplate,
                 solver_prompt: ChatPromptTemplate,
                 tools: list[BaseTool],
                 use_tool_schema: bool = True,
                 callbacks: list[AsyncCallbackHandler] | None = None,
                 detailed_logs: bool = False):
        """Create the agent.

        Args:
            llm: Chat model used by both the planner and the solver.
            planner_prompt: Planner prompt. Its ``{tools}`` and ``{tool_names}`` variables are filled in here,
                and ``{task}`` is supplied when the planner runs.
            solver_prompt: Solver prompt. ``{plan}`` is filled in with the rendered evidence, and ``{task}`` is
                supplied when the solver runs.
            tools: Tools the planner may reference by name.
            use_tool_schema: When True, each tool's input-schema fields are appended to its description.
            callbacks: Callback handlers passed to every LLM and tool invocation.
            detailed_logs: Enable verbose logging of node inputs and outputs.
        """
        super().__init__(llm=llm, tools=tools, callbacks=callbacks, detailed_logs=detailed_logs)

        logger.debug(
            "%s Filling the prompt variables 'tools' and 'tool_names', using the tools provided in the config.",
            AGENT_LOG_PREFIX)
        # str.join never adds a leading or trailing separator, so these are well-formed for any number of tools.
        tool_names = ",".join(tool.name for tool in tools)
        if use_tool_schema:
            logger.debug("%s Adding the tools' input schema to the tools' description", AGENT_LOG_PREFIX)
            tool_names_and_descriptions = "\n".join(
                f"{tool.name}: {tool.description}. "
                f"{INPUT_SCHEMA_MESSAGE.format(schema=tool.input_schema.model_fields)}" for tool in tools)
        else:
            tool_names_and_descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)

        self.planner_prompt = planner_prompt.partial(tools=tool_names_and_descriptions, tool_names=tool_names)
        self.solver_prompt = solver_prompt
        self.tools_dict = {tool.name: tool for tool in tools}

        logger.debug("%s Initialized ReWOO Agent Graph", AGENT_LOG_PREFIX)

    def _get_tool(self, tool_name: str) -> BaseTool | None:
        """Return the configured tool named ``tool_name``, or ``None`` if no such tool is configured."""
        return self.tools_dict.get(tool_name)

    @staticmethod
    def _get_current_step(state: ReWOOGraphState) -> int:
        """Return the index of the next step to execute, or -1 once every step has a recorded result.

        The index equals the number of recorded intermediate results, because the executor records exactly
        one result per step.

        Raises:
            RuntimeError: if the state holds no steps.
        """
        steps = state.steps.content
        if not steps:
            raise RuntimeError('No steps received in ReWOOGraphState')

        completed = len(state.intermediate_results)
        if completed == len(steps):
            # all steps are done
            return -1

        return completed

    @staticmethod
    def _parse_planner_output(planner_output: str) -> AIMessage:
        """Parse the planner's raw text into an ``AIMessage`` whose content is the list of steps.

        Raises:
            ValueError: if the text is not valid JSON, or if it does not decode to a JSON list.
        """
        try:
            steps = json.loads(planner_output)
        except json.JSONDecodeError as ex:
            raise ValueError(f"The output of planner is invalid JSON format: {planner_output}") from ex

        # Downstream nodes index into and iterate over the steps. Anything other than a JSON array (for
        # example a bare string) would otherwise make the executor spin until the recursion limit.
        if not isinstance(steps, list):
            raise ValueError(f"The output of planner must be a JSON list of steps: {planner_output}")

        return AIMessage(content=steps)

    @staticmethod
    def _replace_placeholder(placeholder: str, tool_input: str | dict,
                             tool_output: str | dict | list) -> str | dict:
        """Substitute one evidence placeholder in ``tool_input`` with a previous step's output.

        Dict inputs:
            - A value equal to the placeholder is replaced by ``tool_output`` itself, keeping its type.
            - A string value that merely contains the placeholder gets ``str(tool_output)`` spliced in.
            - Other values are kept as they are.
            - A new dict is returned; the caller's dict is not mutated.

        String inputs: every occurrence of the placeholder is replaced with ``str(tool_output)``.

        Raises:
            TypeError: if ``tool_input`` is neither a ``str`` nor a ``dict``.
        """
        if isinstance(tool_input, str):
            return tool_input.replace(placeholder, str(tool_output))

        if not isinstance(tool_input, dict):
            raise TypeError(f"Unexpected type for tool_input: {type(tool_input)}")

        replaced = {}
        for key, value in tool_input.items():
            if value == placeholder:
                replaced[key] = tool_output
            elif isinstance(value, str) and placeholder in value:
                # The placeholder is only part of the value, so splice in the stringified output.
                replaced[key] = value.replace(placeholder, str(tool_output))
            else:
                replaced[key] = value
        return replaced

    @staticmethod
    def _unwrap_tool_output(tool_message: ToolMessage) -> str | dict | list:
        """Return the payload stored in an intermediate-result message.

        The executor wraps dict tool responses in a one-element list, because ``ToolMessage`` does not accept
        a bare dict. That wrapping is undone here; any other content is returned unchanged.
        """
        content = tool_message.content
        if isinstance(content, list) and len(content) == 1 and isinstance(content[0], dict):
            return content[0]
        return content

    @classmethod
    def _substitute_placeholders(cls, tool_input, intermediate_results: dict[str, ToolMessage]):
        """Replace every recorded evidence placeholder in ``tool_input`` with its tool output.

        Placeholders are applied longest first, so that e.g. ``#E1`` is never substituted inside ``#E10``.
        Empty placeholders are skipped, since they cannot be referenced and would match everywhere.
        """
        for placeholder in sorted(intermediate_results, key=len, reverse=True):
            if not placeholder:
                continue
            tool_output = cls._unwrap_tool_output(intermediate_results[placeholder])
            tool_input = cls._replace_placeholder(placeholder, tool_input, tool_output)
        return tool_input

    @staticmethod
    def _parse_tool_input(tool_input: str | dict):
        """Best-effort conversion of a tool input into structured data.

        Dicts are returned as is. A string is first parsed as JSON. If that fails, it is parsed again with
        single quotes turned into double quotes. If that also fails, the stripped raw string is returned.
        """
        # If the input is already a dictionary, return it as is
        if isinstance(tool_input, dict):
            logger.debug("%s Tool input is already a dictionary. Use the tool input as is.", AGENT_LOG_PREFIX)
            return tool_input

        # If the input is a string, attempt to parse it as JSON
        try:
            tool_input = tool_input.strip()
            # If the input is already a valid JSON string, load it
            tool_input_parsed = json.loads(tool_input)
            logger.debug("%s Successfully parsed structured tool input", AGENT_LOG_PREFIX)

        except JSONDecodeError:
            try:
                # Replace single quotes with double quotes and attempt parsing again
                tool_input_fixed = tool_input.replace("'", '"')
                tool_input_parsed = json.loads(tool_input_fixed)
                logger.debug(
                    "%s Successfully parsed structured tool input after replacing single quotes with double quotes",
                    AGENT_LOG_PREFIX)

            except JSONDecodeError:
                # If it still fails, fall back to using the input as a raw string
                tool_input_parsed = tool_input
                logger.debug("%s Unable to parse structured tool input. Using raw tool input as is.", AGENT_LOG_PREFIX)

        return tool_input_parsed

    async def planner_node(self, state: ReWOOGraphState):
        """Graph node: ask the LLM for a plan and parse it into steps.

        Returns:
            A state update with ``plan`` (the raw LLM message) and ``steps`` (the parsed step list). When no
            task was provided, it instead returns ``result`` set to an error message.
        """
        try:
            logger.debug("%s Starting the ReWOO Planner Node", AGENT_LOG_PREFIX)

            task = str(state.task.content)
            if not task:
                logger.error("%s No task provided to the ReWOO Agent. Please provide a valid task.", AGENT_LOG_PREFIX)
                # ``result`` is typed as AIMessage in ReWOOGraphState, so wrap the error text accordingly.
                return {"result": AIMessage(content=NO_INPUT_ERROR_MESSAGE)}

            planner = self.planner_prompt | self.llm
            plan = await self._stream_llm(
                planner,
                {"task": task},
                RunnableConfig(callbacks=self.callbacks)  # type: ignore
            )

            steps = self._parse_planner_output(str(plan.content))

            if self.detailed_logs:
                agent_response_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(plan.content))
                logger.info("ReWOO agent planner output: %s", agent_response_log_message)

            return {"plan": plan, "steps": steps}

        except Exception as ex:
            logger.exception("%s Failed to call planner_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
            raise

    async def executor_node(self, state: ReWOOGraphState):
        """Graph node: execute the next pending plan step and record its output.

        Exactly one step runs per invocation. Its output (or a tool-not-found message) is stored in
        ``intermediate_results`` under the step's placeholder, which advances ``_get_current_step``.

        Raises:
            RuntimeError: if invoked after all steps have finished.
            ValueError: if the current step is malformed, or if its placeholder was already used by an
                earlier step. In either case the executor could never advance, so it fails fast.
        """
        try:
            logger.debug("%s Starting the ReWOO Executor Node", AGENT_LOG_PREFIX)

            current_step = self._get_current_step(state)
            # The executor node should not be invoked after all steps are finished
            if current_step < 0:
                logger.error("%s ReWOO Executor is invoked with an invalid step number: %s",
                             AGENT_LOG_PREFIX,
                             current_step)
                raise RuntimeError(f"ReWOO Executor is invoked with an invalid step number: {current_step}")

            steps_content = state.steps.content
            if not isinstance(steps_content, list) or current_step >= len(steps_content):
                raise ValueError(f"Invalid steps content or index {current_step}")

            step = steps_content[current_step]
            if not isinstance(step, dict) or not isinstance(step.get("evidence"), dict):
                raise ValueError(f"Invalid step format at index {current_step}")

            step_info = step["evidence"]
            placeholder = step_info.get("placeholder", "")
            tool = step_info.get("tool", "")
            tool_input = step_info.get("tool_input", "")

            if placeholder in state.intermediate_results:
                raise ValueError(f"Duplicate placeholder {placeholder!r} at step index {current_step}")

            # Copy so the update returned to LangGraph does not alias (and mutate) the incoming state.
            intermediate_results = dict(state.intermediate_results)

            # Replace the placeholders in the tool input with the previous tool outputs
            tool_input = self._substitute_placeholders(tool_input, intermediate_results)

            requested_tool = self._get_tool(tool)
            if requested_tool is None:
                configured_tool_names = list(self.tools_dict.keys())
                logger.warning(
                    "%s ReWOO Agent wants to call tool %s. In the ReWOO Agent's configuration within the config "
                    "file, there is no tool with that name: %s",
                    AGENT_LOG_PREFIX,
                    tool,
                    configured_tool_names)

                intermediate_results[placeholder] = ToolMessage(
                    content=TOOL_NOT_FOUND_ERROR_MESSAGE.format(tool_name=tool, tools=configured_tool_names),
                    tool_call_id=tool)
                return {"intermediate_results": intermediate_results}

            if self.detailed_logs:
                logger.debug("%s Calling tool %s with input: %s", AGENT_LOG_PREFIX, requested_tool.name, tool_input)

            # Run the tool. Try to use structured input, if possible
            tool_input_parsed = self._parse_tool_input(tool_input)
            tool_response = await self._call_tool(requested_tool,
                                                  tool_input_parsed,
                                                  RunnableConfig(callbacks=self.callbacks),
                                                  max_retries=3)

            # ToolMessage only accepts str or list[str | dict] as content.
            # Convert into list if the response is a dict.
            if isinstance(tool_response, dict):
                tool_response = [tool_response]

            tool_response_message = ToolMessage(name=tool, tool_call_id=tool, content=tool_response)

            if self.detailed_logs:
                self._log_tool_response(requested_tool.name, tool_input_parsed, str(tool_response))

            intermediate_results[placeholder] = tool_response_message
            return {"intermediate_results": intermediate_results}

        except Exception as ex:
            logger.exception("%s Failed to call executor_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
            raise

    async def solver_node(self, state: ReWOOGraphState):
        """Graph node: render the plan with the collected evidence and ask the LLM for the final answer.

        Each step is rendered as ``Plan: <plan>`` followed by ``<evidence> = <tool>[<tool input>]``. The
        evidence placeholder is replaced by its recorded output, and so are placeholders inside the tool
        input. Steps are separated by newlines.
        """
        try:
            logger.debug("%s Starting the ReWOO Solver Node", AGENT_LOG_PREFIX)

            intermediate_results = state.intermediate_results
            plan_entries = []
            for step in state.steps.content:
                step_info = step["evidence"]
                placeholder = step_info.get("placeholder", "")
                tool_input = self._substitute_placeholders(step_info.get("tool_input", ""), intermediate_results)

                evidence = placeholder
                if placeholder in intermediate_results:
                    evidence = str(self._unwrap_tool_output(intermediate_results[placeholder]))

                step_plan = step.get("plan")
                tool = step_info.get("tool")
                plan_entries.append(f"Plan: {step_plan}\n{evidence} = {tool}[{tool_input}]")

            plan = "\n".join(plan_entries)

            task = str(state.task.content)
            solver_prompt = self.solver_prompt.partial(plan=plan)
            solver = solver_prompt | self.llm

            output_message = await self._stream_llm(solver, {"task": task},
                                                    RunnableConfig(callbacks=self.callbacks))  # type: ignore

            if self.detailed_logs:
                solver_output_log_message = AGENT_CALL_LOG_MESSAGE % (task, str(output_message.content))
                logger.info("ReWOO agent solver output: %s", solver_output_log_message)

            return {"result": output_message}

        except Exception as ex:
            logger.exception("%s Failed to call solver_node: %s", AGENT_LOG_PREFIX, ex, exc_info=True)
            raise

    async def _planner_conditional_edge(self, state: ReWOOGraphState):
        """Route after the planner.

        End the graph when the planner already set a result (e.g. no task was given); otherwise continue
        to the executor.
        """
        if state.result.content:
            logger.debug("%s The ReWOO Planner produced a result directly; ending graph traversal",
                         AGENT_LOG_PREFIX)
            return AgentDecision.END
        return AgentDecision.TOOL

    async def conditional_edge(self, state: ReWOOGraphState):
        """Route after the executor.

        Loop back to the executor while steps remain, then go to the solver. On any error, fall through
        to the END decision, which also maps to the solver.
        """
        try:
            logger.debug("%s Starting the ReWOO Conditional Edge", AGENT_LOG_PREFIX)

            current_step = self._get_current_step(state)
            if current_step == -1:
                logger.debug("%s The ReWOO Executor has finished its task", AGENT_LOG_PREFIX)
                return AgentDecision.END

            logger.debug("%s The ReWOO Executor is still working on the task", AGENT_LOG_PREFIX)
            return AgentDecision.TOOL

        except Exception as ex:
            logger.exception("%s Failed to determine whether agent is calling a tool: %s",
                             AGENT_LOG_PREFIX,
                             ex,
                             exc_info=True)
            logger.warning("%s Ending graph traversal", AGENT_LOG_PREFIX)
            return AgentDecision.END

    async def _build_graph(self, state_schema):
        """Build and compile the planner -> executor (loop) -> solver graph and store it on ``self.graph``."""
        # Local import: END is only needed for the planner's early-exit route.
        from langgraph.graph import END

        try:
            logger.debug("%s Building and compiling the ReWOO Graph", AGENT_LOG_PREFIX)

            graph = StateGraph(state_schema)
            graph.add_node("planner", self.planner_node)
            graph.add_node("executor", self.executor_node)
            graph.add_node("solver", self.solver_node)

            # The planner short-circuits to the end when it already produced a result (e.g. empty task).
            planner_edge_possible_outputs = {AgentDecision.TOOL: "executor", AgentDecision.END: END}
            graph.add_conditional_edges("planner", self._planner_conditional_edge, planner_edge_possible_outputs)

            conditional_edge_possible_outputs = {AgentDecision.TOOL: "executor", AgentDecision.END: "solver"}
            graph.add_conditional_edges("executor", self.conditional_edge, conditional_edge_possible_outputs)

            graph.set_entry_point("planner")
            graph.set_finish_point("solver")

            self.graph = graph.compile()
            logger.debug("%s ReWOO Graph built and compiled successfully", AGENT_LOG_PREFIX)

            return self.graph

        except Exception as ex:
            logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
            raise

    async def build_graph(self):
        """Build the graph using ``ReWOOGraphState`` as the state schema and return the compiled graph."""
        try:
            await self._build_graph(state_schema=ReWOOGraphState)
            return self.graph
        except Exception as ex:
            logger.exception("%s Failed to build ReWOO Graph: %s", AGENT_LOG_PREFIX, ex, exc_info=ex)
            raise

    @staticmethod
    def validate_planner_prompt(planner_prompt: str) -> bool:
        """Check that the planner prompt is non-empty and contains ``{tools}`` and ``{tool_names}``.

        Returns:
            True when the prompt is valid.

        Raises:
            ValueError: listing every problem found, including when the prompt is empty or ``None``.
        """
        errors = []
        if not planner_prompt:
            errors.append("The planner prompt cannot be empty.")
        required_prompt_variables = {
            "{tools}": "The planner prompt must contain {tools} so the planner agent knows about configured tools.",
            "{tool_names}": "The planner prompt must contain {tool_names} so the planner agent knows tool names."
        }
        # ``or ""`` keeps a None prompt on the ValueError path instead of raising TypeError on ``in``.
        prompt_text = planner_prompt or ""
        for variable_name, error_message in required_prompt_variables.items():
            if variable_name not in prompt_text:
                errors.append(error_message)
        if errors:
            error_text = "\n".join(errors)
            # logger.error rather than logger.exception: there is no active exception to attach here.
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True

    @staticmethod
    def validate_solver_prompt(solver_prompt: str) -> bool:
        """Check that the solver prompt is non-empty.

        Returns:
            True when the prompt is valid.

        Raises:
            ValueError: if the prompt is empty or ``None``.
        """
        errors = []
        if not solver_prompt:
            errors.append("The solver prompt cannot be empty.")
        if errors:
            error_text = "\n".join(errors)
            # logger.error rather than logger.exception: there is no active exception to attach here.
            logger.error("%s %s", AGENT_LOG_PREFIX, error_text)
            raise ValueError(error_text)
        return True
